clear all; close all; clc

% -------------------------------------------------------------------------
% READ BEFORE RUNNING
% 1. Set the variable `folder` further down in this script (search for
%    'county-level_data') to the full path of the county-level_data folder,
%    e.g. C:\Users\...\Reproducible_Version\county-level_data
% 2. The folder Reproducible_Version contains two subfolders,
%    result_figures_reproducible and result_files_reproducible. They start
%    empty and receive the figures/files produced by this script. Results
%    that were already generated are available under
%    Results\result_figures and Results\result_files.
% -------------------------------------------------------------------------

%% Input data
% County names (the loops below use the variable vec_county from this file)
load('vec_county.mat');

% Week labels, converted to a char matrix (one row per week)
weeks = char(table2array(readtable('weeks_new_table.xlsx')));

% Event table: the loops below read the county from column 2 and the wave
% start / end dates from columns 4 / 5
table_events = readtable('data_events_omicron_shifted.xlsx');

%% Digital/predictor variable names (column names in the county CSV files)
vec_var1 = { ...
    'gt_covid', ...
    'gt_covid_19', ...
    'gt_howLongDoesCovidLast', ...
    'gt_covidSymptoms', ...
    'gt_covid_19Who', ...
    'up2date', ...
    'twitter_state', ...
    'gt_fever', ...
    'gt_chestPain', ...
    'gt_afterCovidVaccine', ...
    'gt_sideEffectsOfVaccine', ...
    'gt_effectsOfCovidVaccine', ...
    'JHU_cases', ...
    'JHU_deaths'};
% Currently disabled predictors: 'JHU_cases_state', 'JHU_deaths_state'

%% Choose variables and parameters
% Early-warning tolerance in weeks: a digital activation that precedes the
% event by more than this tolerance is considered a false alarm
EW_tol_weeks = 6;

% Minimum number of positive-derivative events scanned by the loop
alphamin = 2;

% 'fast' picks the minimum optimal alpha/threshold values (bottom-left of a
% rectangle); 'slow' picks the maximum optimal values
strategy = 'fast';

% Maximum gap in weeks (also sets the sliding-window length)
window_gap = 5;

% Percentage threshold values tested for the digital data
vec_bg = 0.1:0.05:0.9;

% Tolerance for alpha>1: if a time window has tol_alpha or more points of
% increase (alpha>1), it is marked as an event
vec_alpha = alphamin:window_gap;

%% CODE STARTS HERE: LOOP ON DIGITAL/PREDICTOR VARIABLES STARTS HERE
% For every predictor in vec_var1, every county in vec_county and every
% wave X = 1..N-1 of that county:
%   1. train on weeks 1..(end of wave X) with fun_train and choose the
%      optimal (alpha, bg) pair,
%   2. test on the remaining weeks with fun_test (wave X+1 is the target),
%   3. save a figure (.fig and .png) and the whole workspace (.mat).
% Variable names are kept stable on purpose: save(filename_mat) stores the
% complete workspace, so downstream code may rely on them.

% Folder holding the '<county>_preprocessed.csv' files. EDIT THIS PATH.
folder = 'C:\Users\ch222818\Dropbox\Reproducible_Version_County_Level\county-level_data';
if ~isfolder(folder)
    error('EWS:countyDataFolderNotFound', ...
        'County data folder not found: %s\nEdit the variable "folder" in this script.', folder);
end

% Output folders (relative to the current folder); created when missing so
% that saveas/save cannot fail on a fresh checkout.
dir_fig = fullfile(pwd,'result_figures_reproducible','fig');
dir_png = fullfile(pwd,'result_figures_reproducible','png');
dir_mat = fullfile(pwd,'result_files_reproducible');
if ~isfolder(dir_fig), mkdir(dir_fig); end
if ~isfolder(dir_png), mkdir(dir_png); end
if ~isfolder(dir_mat), mkdir(dir_mat); end

% Parse the week labels once; they are compared against event dates many
% times inside the loops.
weeks_dt = datetime(weeks);

for ind_var1 = 1:length(vec_var1)
    var1     = vec_var1{ind_var1};

    % The up2date signal is only used up to week 61 (lack of data
    % afterwards); every other signal uses all available weeks.
    if strcmp('up2date',var1)
        week_final = 61;
    else
        week_final = length(weeks);
    end

    %%  START LOOP ON COUNTIES
    for ind_county = 1:length(vec_county)
        ind_county, var1   % no semicolon on purpose: echo progress to the console

        % First week of training set
        startweek_train             = 1;

        % Find events for the county (column 2 holds the county name).
        % NOTE(review): contains() is a substring match - confirm that no
        % county name is a substring of another one.
        aux_ind_county              = contains(string(table_events{:,2}),vec_county(ind_county));
        table_selected_events       = table_events(aux_ind_county,:);

        % start loop on training waves
        for ind_last_train_wave = 1:size(table_selected_events,1)-1  % if the given county had N waves, we will train from wave 1 to X, where X=1,2,...N-1 and test on wave X+1

            % week index of wave 1 (column 4 holds the starting dates of the waves)
            aux_ind_wave1                      = weeks_dt == table_selected_events{1,4};
            week_wave1                         = find(aux_ind_wave1);

            % Build vector of training waves
            vec_week_train_event               = week_wave1; % will input in fun_train

            % week index of next train waves
            for ind_train_wave = 2:ind_last_train_wave
                aux_ind_wave                      = weeks_dt==table_selected_events{ind_train_wave,4}; % column 4: starting dates of the waves
                week_wave                         = find(aux_ind_wave);
                vec_week_train_event              = [vec_week_train_event week_wave]; %#ok<AGROW>
            end

            % Find last week of training set
            aux_ind_endweek_train              = weeks_dt==table_selected_events{ind_last_train_wave,5}; % column 5: end dates of the waves
            endweek_train                      = find(aux_ind_endweek_train);

            % Find index of test week (start of wave X+1)
            aux_ind_week_test                  = weeks_dt==table_selected_events{ind_last_train_wave+1,4};
            vec_week_test_event                = find(aux_ind_week_test);

            %% BUILD TIME WINDOWS FOR TRAINING AND TEST DATA SETS

            % Define training and test weeks
            vec_training_time = startweek_train:endweek_train;             % define training weeks
            vec_test_time     = endweek_train+1:week_final;                % define test weeks

            % Define sliding windows for training dataset. The cell arrays
            % are only created when at least one full window fits; their
            % existence is tested below (they are wiped by clearvars at
            % the end of every wave iteration).
            for j0=1:length(vec_training_time)-window_gap
                windows_training{j0} = vec_training_time(j0):vec_training_time(j0) + window_gap; %#ok<SAGROW>
            end

            % Define sliding windows for test dataset
            for jj0=1:length(vec_test_time)-window_gap
                windows_test{jj0} = vec_test_time(jj0):vec_test_time(jj0) + window_gap; %#ok<SAGROW>
            end

            % If we can generate training/test windows and if the test
            % event falls into the available dataset, proceed. The
            % isempty guard matters: when the test-wave start date is not
            % among the weeks, find() returns [] and "&&" would error on a
            % non-scalar operand.
            if exist('windows_training','var') == 1 && exist('windows_test','var') == 1 ...
                    && ~isempty(vec_week_test_event) && all(week_final >= vec_week_test_event)

                %% Load county data
                fileName_county        = [vec_county{ind_county} '_preprocessed.csv'];
                fullFileName           = fullfile(folder, fileName_county);

                table_data             = readtable(fullFileName);                                     % read county table
                var1_timeseries        = table2array(table_data(1:week_final,var1));                  % get array for digital data var1

                % Drop helper variables so they do not end up in the saved workspace
                clear aux_ind_county aux_ind_week_test fileName_county j0 jj0

                %% PART 1: TRAINING STEP - FIXED EVENT AND CHOOSE OPTIMAL THRESHOLD

                % Call training function
                [matrix_ACC_train,matrix_PPV_train,matrix_NPV_train,matrix_TP_train,matrix_FP_train,matrix_TN_train,matrix_FN_train,activation_weeks_var1_train,activation_weeks_var2_train] = fun_train(vec_training_time,windows_training,vec_bg,var1_timeseries,vec_week_train_event,vec_alpha);

                % Proceed if there was a selection window and at least one
                % TP or FP across training simulations (otherwise the PPV
                % matrix would be solely NaN from 0/0 - we discard that case)
                if ~isempty(matrix_ACC_train) && any(~isnan(matrix_PPV_train(:)))

                    % Choose optimal thresholds
                    ind_PPV = matrix_PPV_train == max(max(matrix_PPV_train)); % entries where PPV is max
                    ind_NPV = matrix_NPV_train == max(max(matrix_NPV_train)); % entries where NPV is max
                    ind_ACC = matrix_ACC_train == max(max(matrix_ACC_train)); % entries where ACC is max

                    ind_sum             = ind_PPV + ind_NPV + ind_ACC;    % equals 3 where ACC, PPV and NPV are all at their maximum
                    [ind_max1,ind_max2] = find(ind_sum==3);               % row index -> vec_alpha, column index -> vec_bg

                    % w_sig=1: ACC, PPV and NPV are maximised simultaneously.
                    % w_sig=0: no such entry, fall back to the entries that
                    % maximise PPV only.
                    w_sig               = ~isempty(ind_max1);
                    if ~w_sig
                        [ind_max1,ind_max2] = find(ind_PPV==1);
                    end

                    % There might be several candidates: rank them by the
                    % sum of their indexes. This is one option, others might
                    % be possible as well.
                    ind_max_sum         = ind_max1 + ind_max2;

                    if contains(strategy, 'fast')
                        [~,ww]          = min(ind_max_sum);   % smallest index sum
                    elseif contains(strategy, 'slow')
                        [~,ww]          = max(ind_max_sum);   % largest index sum
                    else
                        error('EWS:unknownStrategy', ...
                            'Unknown strategy "%s": expected ''fast'' or ''slow''.', strategy);
                    end

                    % min/max return the first location on ties
                    alpha_opt           = vec_alpha(ind_max1(ww(1)));    % optimal alpha
                    bg_opt              = vec_bg(ind_max2(ww(1)));       % optimal threshold bg

                    bg_opt, alpha_opt   % no semicolon on purpose: echo chosen parameters

                    %% PART 2: Test on the next wave

                    [select_windows_th_act,select_windows_deriv_act,windows_TP_test,windows_FP_test,windows_TN_test,windows_FN_test,num_TP_test, num_FP_test, num_TN_test, num_FN_test,ACC_test,PPV_test,NPV_test,activation_weeks_var1_test,activation_weeks_var2_test] = fun_test(alpha_opt,bg_opt,vec_training_time,windows_test,var1_timeseries,vec_week_test_event);

                    clear ind_ACC ww ind_sum ind_max1 ind_max2 ind_max_sum ind_NPV ind_PPV ind_train_wave

                    %% PLOT RESULTS
                    h= figure(1);
                    set(h, 'Visible', 'off')
                    subplot(2,1,1)
                    plot(var1_timeseries,'b.-'); hold on
                    xline(endweek_train)
                    % Threshold line: bg_opt times the maximum of the signal over the training weeks
                    line([endweek_train week_final],[bg_opt*max(var1_timeseries(1:endweek_train)) bg_opt*max(var1_timeseries(1:endweek_train))],'Color','k')
                    ylabel(var1,'Interpreter','none')   % names contain underscores; do not render them as subscripts

                    % Collect one activation week per cell entry.
                    % NOTE(review): assumes each cell of
                    % activation_weeks_var1_test holds a scalar (week or NaN) - confirm in fun_test.
                    aux_activation_weeks_var1_test = nan(1,length(activation_weeks_var1_test));
                    for i1=1:length(activation_weeks_var1_test)
                        aux_activation_weeks_var1_test(i1)= activation_weeks_var1_test{i1};
                    end
                    unique_activation_weeks_var1_test = unique(aux_activation_weeks_var1_test);

                    % Mark every valid (non-NaN) activation week. An "if" on
                    % the whole vector would skip the markers as soon as a
                    % single entry is NaN.
                    valid_activation_weeks_var1_test = unique_activation_weeks_var1_test(~isnan(unique_activation_weeks_var1_test));
                    if ~isempty(valid_activation_weeks_var1_test)
                        plot(valid_activation_weeks_var1_test,var1_timeseries(valid_activation_weeks_var1_test),'bo');
                        first_activation_week_var1_test = valid_activation_weeks_var1_test(1);
                    else
                        first_activation_week_var1_test = NaN;   % no out-of-sample activation
                    end
                    title(['alpha_{opt} =  ' num2str(alpha_opt) '. First out of sample activation: week ' num2str(first_activation_week_var1_test) '.'])

                    subplot(2,1,2)
                    cases_timeseries = table2array(table_data(1:week_final,'JHU_cases'));
                    plot(cases_timeseries,'r.-'); hold on
                    title(['Rt activation: week ' num2str(vec_week_test_event) '.'])
                    xline(endweek_train)
                    ylabel('JHU cases')

                    if vec_week_test_event<=week_final
                        plot(vec_week_test_event,cases_timeseries(vec_week_test_event),'rs'); hold on
                    end

                    sgtitle([ vec_county{ind_county} ': training up to wave ' num2str(ind_last_train_wave) ' and testing on wave ' num2str(ind_last_train_wave+1) '.'],'Interpreter','none')
                    xlabel('weeks')

                    %% SAVE FILES AND FIGURES
                    % Common part of all output file names for this run
                    run_tag        = ['_parag_prob_training_upto_wave_' num2str(ind_last_train_wave) '_test_wave_' num2str(ind_last_train_wave+1) '_alphamin_' num2str(alphamin) '_strategy_' strategy '_signal_' var1 '_ind_county_' num2str(ind_county)];

                    % fullfile keeps the paths valid on every operating system
                    figurename_fig = fullfile(dir_fig, ['Figure_omicron' run_tag '.fig']);
                    figurename_png = fullfile(dir_png, ['Figure_omicron' run_tag '.png']);

                    saveas(h,figurename_fig)
                    saveas(h,figurename_png)

                    close all % close figure after saving

                    % File name in mat format
                    filename_mat = fullfile(dir_mat, ['Results_omicron_shifted' run_tag '.mat']);

                    % Save file with results (whole workspace)
                    save(filename_mat)
                end % end of if statement (valid training matrices)
            end     % end of if statement (windows and test event available)

            % Reset everything that is specific to this wave; keep inputs,
            % parameters, configuration and the loop state of the county
            clearvars -except table_events table_pop table_Rt var1 vec_alpha vec_bg vec_county vec_var1 week_final weeks weeks_dt window_gap strategy alphamin ind_county table_selected_events  startweek_train EW_tol_weeks folder dir_fig dir_png dir_mat
        end % end loop on waves
    end % end loop on counties
    clear var1 week_final

end % end loop on predictor